/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_NEON_H_
#define MINDSPORE_NNACL_FP32_SOFTMAX_NEON_H_

#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/intrinsics/ms_simd_neon_instructions.h"

#ifdef __cplusplus
extern "C" {
#endif

// Per-ISA configuration for the NEON helpers defined in this header.
// NOTE(review): MS_SIMD_INSTRUCTION / MS_SIMD_NEON presumably steer how the generic SIMD_* macros
// expand in the included ms_simd headers -- confirm there.
#define MS_SIMD_INSTRUCTION MS_SIMD_NEON_INSTRUCTION
// Number of float elements every vectorized loop iteration in this header advances by.
#define BLOCK_NUM 4
#define MS_SIMD_NEON

/**
 * Vectorized part of a maximum search over src[cur_batch_offset + index ...].
 *
 * The vector path is taken only when channel >= BLOCK_NUM * BLOCK_NUM; for shorter
 * inputs nothing is loaded, *max is not written and index is returned unchanged.
 * Otherwise an accumulator is seeded from *max, combined with successive BLOCK_NUM-wide
 * loads through SIMD_MAX_F32 for as long as a full load still fits below channel, and
 * finally folded back into *max with SIMD_GET_MAX_F32.
 *
 * @param index             first element (relative to cur_batch_offset) to process
 * @param src               input buffer
 * @param cur_batch_offset  element offset added to src before indexing
 * @param max               in/out: seed value on entry, reduced result on exit
 * @param channel           upper bound for index
 * @return the first index that was not processed here; elements from that index up to
 *         channel are left for the caller.
 */
static inline int64_t SoftmaxNormGetMaxNEON(int64_t index, const float *src, int cur_batch_offset,
  float *max, int channel) {
  // Short inputs bypass the vector path entirely.
  if (channel < BLOCK_NUM * BLOCK_NUM) {
    return index;
  }

  const float *base = src + cur_batch_offset;
  // Exclusive bound on the start index of a full BLOCK_NUM-wide load.
  const int vec_limit = channel - BLOCK_NUM + 1;

  SIMD_F32 running_max = SIMD_MOV_F32(*max);
  while (index < vec_limit) {
    running_max = SIMD_MAX_F32(running_max, SIMD_LD_F32(base + index));
    index += BLOCK_NUM;
  }
  *max = SIMD_GET_MAX_F32(running_max);
  return index;
}

/**
 * Vectorized part of "dst = src - max" over the elements starting at cur_batch_offset + index.
 *
 * Processes BLOCK_NUM floats per iteration for as long as a full block still fits below
 * channel: each block is loaded from src, has the broadcast value of max subtracted via
 * SIMD_SUB_F32, and is stored at the same position in dst.
 *
 * @param index             first element (relative to cur_batch_offset) to process
 * @param src               input buffer
 * @param dst               output buffer, written at the same offsets that are read from src
 * @param cur_batch_offset  element offset added to both src and dst
 * @param max               scalar subtracted from every element
 * @param channel           upper bound for index
 * @return the first index that was not processed here; elements from that index up to
 *         channel are left for the caller.
 */
static inline int64_t SoftmaxNormCalcNormNEON(int64_t index, const float *src, float *dst,
  int cur_batch_offset, float max, int channel) {
  // The broadcast of max does not depend on the loop, so build it once up front.
  SIMD_F32 max_val = SIMD_MOV_F32(max);
  for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
    SIMD_F32 output = SIMD_SUB_F32(SIMD_LD_F32(src + cur_batch_offset + index), max_val);
    SIMD_ST_F32(dst + cur_batch_offset + index, output);
  }
  return index;
}

/**
 * Vectorized part of "dst = exp(src - max)" that also accumulates the sum of the results.
 *
 * When _WIN32 is not defined, BLOCK_NUM floats are processed per iteration for as long as
 * a full block still fits below channel: each block is loaded from src, has the broadcast
 * value of max subtracted, is passed through SIMD_EXP_F32, stored to dst and added to a
 * vector accumulator. After the loop the accumulator is reduced with SIMD_GET_SUM_F32 and
 * added onto *exp_sum.
 *
 * When _WIN32 is defined the body is compiled out: nothing is read or written and index
 * is returned unchanged.
 *
 * @param index             first element (relative to cur_batch_offset) to process
 * @param src               input buffer
 * @param dst               output buffer, written at the same offsets that are read from src
 * @param cur_batch_offset  element offset added to both src and dst
 * @param max               scalar subtracted from every element before exponentiation
 * @param exp_sum           in/out: the reduced sum of this call is added to the stored value
 * @param channel           upper bound for index
 * @return the first index that was not processed here; elements from that index up to
 *         channel are left for the caller.
 */
static inline int64_t SoftmaxLastAxisGetExpSumNEON(int64_t index, const float *src, float *dst,
  int cur_batch_offset, float max, float *exp_sum, int channel) {
#ifndef _WIN32
  SIMD_F32 sum_val = SIMD_SET0_F32;
  // The broadcast of max does not depend on the loop, so build it once up front.
  SIMD_F32 max_val = SIMD_MOV_F32(max);
  for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
    SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index);
    SIMD_F32 output = SIMD_SUB_F32(input, max_val);
    SIMD_F32 exp_out = SIMD_EXP_F32(output);
    sum_val = SIMD_ADD_F32(sum_val, exp_out);
    SIMD_ST_F32(dst + cur_batch_offset + index, exp_out);
  }
  // Added unconditionally; if the loop never ran the accumulator is still SIMD_SET0_F32.
  *exp_sum += SIMD_GET_SUM_F32(sum_val);
#endif
  return index;
}

/**
 * Vectorized part of "dst = src * exp_sum" over the elements starting at cur_batch_offset + index.
 *
 * Processes BLOCK_NUM floats per iteration for as long as a full block still fits below
 * channel: each block is loaded from src, multiplied by the broadcast value of exp_sum and
 * stored at the same position in dst.
 * NOTE(review): the value is multiplied, not divided, so callers presumably pass the
 * reciprocal of the exponent sum -- confirm at the call site.
 *
 * @param index             first element (relative to cur_batch_offset) to process
 * @param src               input buffer
 * @param dst               output buffer, written at the same offsets that are read from src
 * @param cur_batch_offset  element offset added to both src and dst
 * @param exp_sum           scalar factor applied to every element
 * @param channel           upper bound for index
 * @return the first index that was not processed here; elements from that index up to
 *         channel are left for the caller.
 */
static inline int64_t SoftmaxLastAxisGetResultNEON(int64_t index, const float *src, float *dst,
  int cur_batch_offset, float exp_sum, int channel) {
  const SIMD_F32 factor = SIMD_MOV_F32(exp_sum);
  // Exclusive bound on the start index of a full BLOCK_NUM-wide load/store.
  const int vec_limit = channel - BLOCK_NUM + 1;
  const float *in_base = src + cur_batch_offset;
  float *out_base = dst + cur_batch_offset;

  while (index < vec_limit) {
    SIMD_F32 scaled = SIMD_MUL_F32(SIMD_LD_F32(in_base + index), factor);
    SIMD_ST_F32(out_base + index, scaled);
    index += BLOCK_NUM;
  }
  return index;
}

// Remove the per-ISA configuration macros that were defined near the top of this header.
#undef MS_SIMD_INSTRUCTION
#undef BLOCK_NUM

#undef MS_SIMD_NEON
#ifdef __cplusplus
};
#endif
#endif
